Conversation

@XChy (Member) commented Sep 1, 2025

Address the TODO and implement constant folding of the intermediate multiplication result for vpmadd52l/vpmadd52h.
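For reference, vpmadd52l/vpmadd52h multiply the low 52 bits of two 64-bit source lanes into a 104-bit product, then add either the low (L) or the high (H) 52 bits of that product to the accumulator operand. A minimal scalar model of one lane, assuming the Clang/GCC unsigned __int128 extension (a sketch with invented helper names, not code from the patch):

    #include <cstdint>

    // One 64-bit lane of VPMADD52LUQ: z + (low 52 bits of lo52(x) * lo52(y)).
    uint64_t madd52lo(uint64_t z, uint64_t x, uint64_t y) {
      const uint64_t mask52 = (1ULL << 52) - 1;
      unsigned __int128 p = (unsigned __int128)(x & mask52) * (y & mask52);
      return z + (uint64_t)(p & mask52);
    }

    // One 64-bit lane of VPMADD52HUQ: z + (bits 52..103 of the same product).
    uint64_t madd52hi(uint64_t z, uint64_t x, uint64_t y) {
      const uint64_t mask52 = (1ULL << 52) - 1;
      unsigned __int128 p = (unsigned __int128)(x & mask52) * (y & mask52);
      return z + (uint64_t)(p >> 52);
    }

When both multiplicand lanes are known constants, the selected 52 bits of the product are themselves a constant C, so the node can be rewritten as the plain ADD C + Z that this patch emits (selected as vpaddq in the new tests).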

@llvmbot (Member) commented Sep 1, 2025

@llvm/pr-subscribers-backend-x86

Author: XChy (XChy)

Changes

Address the TODO and implement constant folding of the intermediate multiplication result for vpmadd52l/vpmadd52h.


Full diff: https://github.com/llvm/llvm-project/pull/156293.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+23-10)
  • (modified) llvm/test/CodeGen/X86/combine-vpmadd52.ll (+107)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d78cf00a5a2fc..840c2730625c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44954,26 +44954,39 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VPMADD52L:
   case X86ISD::VPMADD52H: {
-    KnownBits KnownOp0, KnownOp1;
+    KnownBits Known52BitsOfOp0, Known52BitsOfOp1;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
     SDValue Op2 = Op.getOperand(2);
     //  Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     //  operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
-    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
-                             TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts,
+                             Known52BitsOfOp0, TLO, Depth + 1))
       return true;
 
-    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
-                             TLO, Depth + 1))
+    if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts,
+                             Known52BitsOfOp1, TLO, Depth + 1))
       return true;
 
-    // X * 0 + Y --> Y
-    // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
-    // zeroes.
-    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
-      return TLO.CombineTo(Op, Op2);
+    KnownBits KnownMul;
+    Known52BitsOfOp0 = Known52BitsOfOp0.trunc(52);
+    Known52BitsOfOp1 = Known52BitsOfOp1.trunc(52);
+    if (Opc == X86ISD::VPMADD52L) {
+      KnownMul =
+          KnownBits::mul(Known52BitsOfOp0.zext(104), Known52BitsOfOp1.zext(104))
+              .trunc(52);
+    } else {
+      KnownMul = KnownBits::mulhu(Known52BitsOfOp0, Known52BitsOfOp1);
+    }
+    KnownMul = KnownMul.zext(64);
+
+    // C1 * C2 + Z --> C3 + Z
+    if (KnownMul.isConstant()) {
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), SDLoc(Op0), VT);
+      return TLO.CombineTo(Op,
+                           TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2));
+    }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;
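As a sanity check of the arithmetic behind the change above: the L form widens both 52-bit values before multiplying, since the full product needs 104 bits before truncating back to the low 52, while the H form is exactly KnownBits::mulhu of the 52-bit values, i.e. the high half of their double-width product. The constants exercised by the new tests below can be verified standalone (a sketch using unsigned __int128, not the APInt code itself):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint64_t mask52 = (1ULL << 52) - 1;
      // test_vpmadd52l_mul_lo52_const: low 52 bits of 123 * 456, so the
      // folded vpaddq adds a splat of 56088.
      unsigned __int128 p = (unsigned __int128)123 * 456;
      assert((uint64_t)(p & mask52) == 56088);
      // test_vpmadd52h_mul_hi52_const: (1 << 51) * (1 << 51) == 1 << 102,
      // and bits 52..103 of that are 1 << 50.
      p = (unsigned __int128)(1ULL << 51) * (1ULL << 51);
      assert((uint64_t)(p >> 52) == (1ULL << 50));
      return 0;
    }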
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..1e075bfe12a31 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
   ret <2 x i64> %1
 }
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+  ; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+  ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+  ; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 40)
+  %and2 = lshr <2 x i64> %x1, splat (i64 40)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 30)
+  %and2 = lshr <2 x i64> %x1, splat (i64 43)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
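The mask and shift tests fold even though the multiplicands are not constants: known-zero bits alone are enough to make the selected 52 bits of the product a known constant zero, so the madd reduces to its accumulator. A quick numeric check of those cases (again a standalone sketch with unsigned __int128):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    int main() {
      const uint64_t mask52 = (1ULL << 52) - 1;
      // test_vpmadd52l_mul_lo52_mask: each multiplicand is x & (1 << 30),
      // i.e. 0 or 1 << 30, so the product is 0 or 1 << 60, and its low
      // 52 bits are always zero.
      for (uint64_t x : {0ULL, 1ULL << 30})
        for (uint64_t y : {0ULL, 1ULL << 30})
          assert((uint64_t)(((unsigned __int128)x * y) & mask52) == 0);
      // test_vpmadd52h_mul_hi52_mask: each multiplicand is lshr'd by 40,
      // hence at most 24 bits; the product stays below 1 << 48, so its
      // high 52 bits (bits 52..103) are always zero.
      unsigned __int128 p =
          (unsigned __int128)(UINT64_MAX >> 40) * (UINT64_MAX >> 40);
      assert((uint64_t)(p >> 52) == 0);
      return 0;
    }

The negative tests leave a bit in range on purpose: (1 << 21) * (1 << 30) can set bit 51 in the low half, and (x >> 30) * (y >> 43) can reach bits 52..54 in the high half, so neither result is a known constant and the vpmadd52 instruction remains.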

@XChy (Member Author) commented on the new code:

        TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2));
    }

    // TODO: Compute the known bits for VPMADD52L/VPMADD52H.

Easy to resolve now.

XChy changed the title: [X86] Fold C1 * C2 + Z --> C3 + Z for vpmadd52l/vpmadd52h → [X86] Fold X * Y + Z --> C3 + Z for vpmadd52l/vpmadd52h (Sep 1, 2025)
XChy changed the title: [X86] Fold X * Y + Z --> C3 + Z for vpmadd52l/vpmadd52h → [X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (Sep 1, 2025)
@RKSimon (Collaborator) left a comment:

LGTM with one minor - cheers

@RKSimon (Collaborator) commented on llvm/test/CodeGen/X86/combine-vpmadd52.ll:

    define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
      ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes

(style) move these manual comments above the define to avoid the update script potentially mangling it

XChy force-pushed the perf/simplify-intermediate-VPMADD52 branch from 56cc479 to 880d0ca on September 1, 2025, 16:38
XChy enabled auto-merge (squash) on September 1, 2025, 16:38
XChy merged commit c241eb3 into llvm:main on Sep 1, 2025
9 checks passed